# coding=utf-8

import time
import pymysql
from pymysql.cursors import DictCursor


class SQLRequest:
    """Execute a SQL script on a MySQL server via pymysql and record the outcome.

    The instance doubles as a context manager: entering it opens a connection
    and yields a ``DictCursor``; leaving it commits (or rolls back on error)
    and releases the cursor and the connection.

    After a successful :meth:`execute` the ``request_headers``,
    ``request_body``, ``response_headers`` and ``response_body`` dicts are
    populated from the instance state.
    """

    def __init__(self, host, user, password, port, charset, connect_timeout, db_type, sql):
        """
        :param host: server host name
        :param user: login user
        :param password: login password
        :param port: server port
        :param charset: connection character set
        :param connect_timeout: connection timeout, in seconds
        :param db_type: database type label; only stored and echoed in the
            header dicts, it does not influence how the SQL is executed here
        :param sql: SQL script (one or more ``;``-separated statements)
        """
        self.host = host
        self.user = user
        self.password = password
        self.port = port
        self.charset = charset
        self.connect_timeout = connect_timeout
        self.db_type = db_type
        self.sql = sql
        # Request / response views, filled in by _handle_response()
        self.response_headers = None  # type: dict
        self.response_body = None  # type: dict
        self.request_headers = None  # type: dict
        self.request_body = None  # type: dict
        self.affected_rows = 0  # affected row count of the last statement
        self.all_rows = ()  # rows fetched for the last statement
        # Execution time in milliseconds
        self.elapsed_time = 0
        # Set by __enter__; declared here so the attributes always exist
        self.connection = None
        self.cursor = None

    def send(self):
        """Execute the SQL script this request was created with."""
        self.execute(query=self.sql)

    def execute(self, query, args=None):
        """Run every statement of ``query`` on one connection and record the result.

        ``affected_rows`` and ``all_rows`` are overwritten per statement, so
        after the call they describe the LAST statement only. If any statement
        raises, the transaction is rolled back by ``__exit__``, the exception
        propagates, and neither ``elapsed_time`` nor the request/response
        dicts are updated.

        :param query: single- or multi-statement SQL text
        :param args: parameters handed to ``cursor.execute`` for EVERY
            statement of the script
        """
        # Split the script into individual statements
        sql_list = self._sql_split(query=query)
        # send request
        _start_clock = time.time()
        with self as cursor:
            for sql in sql_list:
                self.affected_rows = cursor.execute(query=sql, args=args)
                self.all_rows = cursor.fetchall()
        # recv response
        _end_clock = time.time()
        # Truncate to whole milliseconds, plus one (kept from the original
        # formula, so the reported time is never 0).
        self.elapsed_time = int((_end_clock - _start_clock) * 1000) + 1
        self._handle_response()

    def _sql_split(self, query):
        """Split SQL text into a list of single-line statements.

        Lines that are blank or start with ``--`` are dropped, the remainder
        is split on ``;``, embedded newlines are replaced by a space and each
        statement is stripped. Empty / whitespace-only statements are skipped
        so they are never sent to the server.

        Limitations: this is a textual split, not a SQL parser. A ``;`` inside
        a string literal or comment still splits the statement, and a trailing
        inline ``--`` comment will comment out the rest of the statement once
        its lines are joined.

        :param query: single- or multi-line SQL text
        :type query: str
        :return: the individual statements, in script order
        :rtype: list
        """
        # Drop comment-only lines and blank lines
        kept_lines = []
        for line in query.split('\n'):
            stripped = line.strip()
            if not stripped or stripped.startswith('--'):
                continue
            kept_lines.append(line)

        statements = []
        for segment in '\n'.join(kept_lines).split(';'):
            # Flatten first, then strip: a segment made only of newlines and
            # spaces (e.g. between two consecutive ';') must be discarded.
            statement = segment.replace('\n', ' ').strip()
            if statement:
                statements.append(statement)
        return statements

    def __enter__(self):
        """Open the connection and return a dict cursor on it.

        No default database is selected; only host, credentials, port,
        charset and connect timeout are passed to ``pymysql.connect``.

        :return: cursor
        :rtype: DictCursor
        """
        self.connection = pymysql.connect(
            host=self.host,
            user=self.user,
            password=self.password,
            port=self.port,
            charset=self.charset,
            connect_timeout=self.connect_timeout
        )
        self.cursor = self.connection.cursor(cursor=DictCursor)
        return self.cursor

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Commit on success, roll back on error, then always release resources.

        Returns ``None``, so an exception raised inside the ``with`` body is
        never suppressed.
        """
        try:
            if exc_type:
                self.connection.rollback()
            else:
                self.connection.commit()
        finally:
            # Close the cursor and the connection even when commit/rollback
            # or the cursor close itself fails, so the connection never leaks.
            try:
                self.cursor.close()
            finally:
                self.connection.close()

    def _handle_response(self):
        """Build the response headers, request body, request headers and response body."""
        self._handle_response_headers()
        self._handle_request_body()
        self._handle_request_headers()
        self._handle_response_body()

    def _connection_info(self):
        """Return a fresh dict with the connection parameters and SQL of this request."""
        # NOTE(review): the plaintext password ends up in both header dicts;
        # confirm that downstream consumers do not log or display it.
        return {
            'host': self.host,
            'port': self.port,
            'user': self.user,
            'password': self.password,
            'connect_timeout': self.connect_timeout,
            'sql': self.sql,
            'charset': self.charset,
            'db_type': self.db_type,
        }

    def _handle_response_headers(self):
        """Response headers: the connection info plus the elapsed time in ms."""
        headers = self._connection_info()
        headers['elapsed_time'] = self.elapsed_time
        self.response_headers = headers

    def _handle_request_body(self):
        """Request body: the SQL script as given to the constructor."""
        self.request_body = {
            'sql': self.sql,
        }

    def _handle_request_headers(self):
        """Request headers: the connection info."""
        self.request_headers = self._connection_info()

    def _handle_response_body(self):
        """Response body: affected row count and fetched rows of the last statement."""
        self.response_body = {
            'affected_rows': self.affected_rows,
            'all_rows': self.all_rows,
        }


def make_request(host, port, user, password, connect_timeout, db_type, charset, sql):
    """Build a ``SQLRequest`` from loosely-typed settings.

    ``port`` and ``connect_timeout`` are converted with ``int()``; every other
    argument is forwarded unchanged.

    :param host: server host name
    :param port: server port (anything ``int()`` accepts)
    :param user: login user
    :param password: login password
    :param connect_timeout: connection timeout (anything ``int()`` accepts)
    :param db_type: database type label, forwarded as-is
    :param charset: connection character set
    :param sql: SQL script
    :return: the new, not yet executed request
    """
    port_number = int(port)
    timeout_value = int(connect_timeout)
    return SQLRequest(
        host=host,
        user=user,
        password=password,
        port=port_number,
        charset=charset,
        connect_timeout=timeout_value,
        db_type=db_type,
        sql=sql,
    )


if __name__ == '__main__':
    # Intentionally empty: running this module as a script does nothing.
    pass
